{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "emnist_inference_cnn.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "VltpCBUpxzZA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torchvision import datasets, transforms\n",
        "import torchvision\n",
        "import matplotlib.pyplot as plt\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import logging\n",
        "import numpy as np\n",
        "import random\n",
        "import time"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TkbyCxh0FmXK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 74
        },
        "outputId": "050b391d-f3e9-4f03-de4b-addece52c796"
      },
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2\n",
        "\n",
        "%matplotlib inline\n",
        "\n",
        "# Link to google drive and download data\n",
        "\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive/')\n"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount(\"/content/gdrive/\", force_remount=True).\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ou8ZLK7u7vZu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Define the Net\n",
        "\n",
        "\n",
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Net, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(28, 64, (5,5), padding = 2)\n",
        "        self.conv1_bn = nn.BatchNorm2d(64)\n",
        "\n",
        "        self.conv2 = nn.Conv2d(64, 128, 2, padding = 2)\n",
        "\n",
        "        self.fc1 = nn.Linear(2048, 1024)\n",
        "\n",
        "        self.dropout = nn.Dropout(0.3)\n",
        "\n",
        "        self.fc2 = nn.Linear(1024, 512)\n",
        "\n",
        "        self.bn = nn.BatchNorm1d(1)\n",
        "\n",
        "        self.fc3 = nn.Linear(512, 128)\n",
        "\n",
        "        self.fc4 = nn.Linear(128,47)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = F.relu(self.conv1(x))\n",
        "        x = F.max_pool2d(x, 2, 2)\n",
        "        x = self.conv1_bn(x)\n",
        "\n",
        "        x = F.relu(self.conv2(x))\n",
        "        x = F.max_pool2d(x, 2, 2)\n",
        "\n",
        "\n",
        "        x = x.view(-1, 2048)\n",
        "        x = F.relu(self.fc1(x))\n",
        "\n",
        "        x = self.dropout(x)\n",
        "\n",
        "        x = self.fc2(x)\n",
        "\n",
        "        x = x.view(-1, 1, 512)\n",
        "        x = self.bn(x)\n",
        "\n",
        "        x = x.view(-1, 512)\n",
        "        x = self.fc3(x)\n",
        "        x = self.fc4(x)\n",
        "\n",
        "        #return F.log_softmax(x, dim=1)\n",
        "        \n",
        "        return x"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CFbtWnra8p-t",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "root = ''\n",
        "\n",
        "#test data tranform\n",
        "transform_valid = transforms.Compose(\n",
        "    [\n",
        "     transforms.ToTensor(),\n",
        "     \n",
        "    ])\n",
        "emnist_test = datasets.EMNIST(root,split = 'balanced', train=False, download=True, transform = transform_valid)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EtVmK-RxNMKj",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 233
        },
        "outputId": "9fb046d0-5ebe-4092-bc53-e8c6548d227e"
      },
      "source": [
        "# Download wights\n",
        "\n",
        "!wget https://imadelhanafi.com/data/draft/cnn_weights_blog.pth"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2019-07-28 10:51:27--  https://imadelhanafi.com/data/draft/cnn_weights_blog.pth\n",
            "Resolving imadelhanafi.com (imadelhanafi.com)... 104.28.28.48, 104.28.29.48, 2606:4700:30::681c:1d30, ...\n",
            "Connecting to imadelhanafi.com (imadelhanafi.com)|104.28.28.48|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 11093914 (11M) [application/octet-stream]\n",
            "Saving to: ‘cnn_weights_blog.pth’\n",
            "\n",
            "\rcnn_weights_blog.pt   0%[                    ]       0  --.-KB/s               \rcnn_weights_blog.pt  66%[============>       ]   7.07M  35.3MB/s               \rcnn_weights_blog.pt 100%[===================>]  10.58M  46.2MB/s    in 0.2s    \n",
            "\n",
            "2019-07-28 10:51:28 (46.2 MB/s) - ‘cnn_weights_blog.pth’ saved [11093914/11093914]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mwy7oKlSJ9Kj",
        "colab_type": "code",
        "outputId": "5e43bab6-2e5a-4f0c-b1f3-2c98b6e9da2f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 354
        }
      },
      "source": [
        "# Download weights\n",
        "\n",
        "net = Net()\n",
        "\n",
        "model_weights = 'cnn_weights_blog.pth'\n",
        "net.load_state_dict(torch.load(model_weights)[\"state_dict\"])\n",
        "net.eval()\n",
        "\n",
        "# Send to GPU\n",
        "\n",
        "device = torch.device('cuda:0')\n",
        "net.to(device)\n",
        "\n",
        "random_img = 13\n",
        "\n",
        "class_mapping = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt'\n",
        "\n",
        "print('Inference on GPU')\n",
        "start = time.time()\n",
        "with torch.no_grad():\n",
        "\n",
        "  data = emnist_test.data[random_img] # (1,28,28)\n",
        "  \n",
        "  data = data.type(torch.FloatTensor)\n",
        "  data = data/255\n",
        "  \n",
        "  data = data.view(1, 28, 28, 1).to(device)\n",
        "  data = torch.transpose(data, 1, 2)\n",
        "  \n",
        "  out = net(data)\n",
        "  probabilities = F.softmax(out, dim = 1)\n",
        "  pred_y = torch.max(probabilities, 1)[1].cpu().data.numpy()\n",
        "  \n",
        "  print(class_mapping[int(pred_y)], 'prediction')\n",
        "  print(torch.max(probabilities, 1)[0], 'probability')\n",
        "  plt.imshow(data.cpu().reshape([28, 28]), cmap='Greys_r')\n",
        "  print(class_mapping[int(emnist_test.targets[random_img].numpy())], 'real value')\n",
        "  plt.show()\n",
        "\n",
        "end = time.time()\n",
        "print(\"inference time on GPU: \", end-start)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Inference on GPU\n",
            "G prediction\n",
            "tensor([0.9985], device='cuda:0') probability\n",
            "G real value\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEGVJREFUeJzt3W+MVGWWx/HfEUT+TIuyhJY4uCjg\nmgmKYxo0prORjBg1k6Av/MMLg8kIJgruxDGRsC8kJppxs+NEE2JE/ANk1lEzYxQ1y7hkDbNRCaAi\nKNogMhmwoUcBxxFxBM++6KvTat9z27pVdat5vp+k01V16qk6FP3rW9XPvfcxdxeA9BxXdQMAqkH4\ngUQRfiBRhB9IFOEHEkX4gUQRfiBRhB9IFOEHEjW0mU9mZuxOCDSYu9tA7ldqy29ml5rZu2a2w8wW\nlXksAM1lte7bb2ZDJHVJmiVpt6QNkua4+9vBGLb8QIM1Y8s/Q9IOd9/p7n+X9FtJs0s8HoAmKhP+\nUyX9uc/13dlt32Bm881so5ltLPFcAOqs4X/wc/dlkpZJvO0HWkmZLf8eSRP6XP9hdhuAQaBM+DdI\nmmJmp5vZMEnXSnq2Pm0BaLSa3/a7+xEzWyBpjaQhkh5x97fq1hmAhqp5qq+mJ+MzP9BwTdnJB8Dg\nRfiBRBF+IFGEH0gU4QcSRfiBRDX1eH60nqFD4x+BsWPHhvX29vawPn369NzatGnTwrGbN28O66tX\nrw7rPT09uTVWqmLLDySL8AOJIvxAogg/kCjCDySK8AOJ4qi+Y8DIkSNza6eddlo49vrrrw/rF110\nUVifNGlSWB8xYkRubfjw4eHYw4cPh/VXXnklrN988825ta6urnDsYMZRfQBChB9IFOEHEkX4gUQR\nfiBRhB9IFOEHEsU8fxNEc91SfNirJF1zzTVh/dprr82ttbW1hWOLDukt68svv6x5rNmApqtzrV+/\nPrfW2dkZji3Td9WY5wcQIvxAogg/kCjCDySK8AOJIvxAogg/kKhSk7xmtkvSJ5KOSjri7h31aGqw\nOeuss8L6ggULwvqcOXPCeiPn6ov28yiq79u3L6y/9NJLubWDBw+GY8eNGxfWr7zyyrA+efLk3Npx\nx8XbvcE8zz9Q9djDY6a7f1iHxwHQRLztBxJVNvwu6Q9mtsnM5tejIQDNUfZtf6e77zGzcZJeNLN3\n3H1d3ztkvxT4xQC0mFJbfnffk33vkfS0pBn93GeZu3ek+sdAoFXVHH4zG2VmbV9dlnSJpK31agxA\nY5V5298u6enssMuhkv7L3f+7Ll0BaLiaw+/uOyXFaywfQy688MLc2po1a8Kxo0aNKvXcRXPtBw4c\nyK1t3749HLt8+fKwHh0TL0k7duwI65999llYjxx//PFhfenSpWF95syZNT93CpjqAxJF+IFEEX4g\nUYQfSBThBxJF+IFENfa8zYPIlClTwvp9992XWys7lXfo0KGwvmXLlrB+++2359aKlqLu6ekJ61Ue\n2vrFF1+E9QcffDCsb9u2Lbd25MiRmno6lrDlBxJF+IFEEX4gUYQfSBThBxJF+IFEEX4gUcks0V10\neu2VK1eG9Y6O/BMRFc2F7927N6xfcsklYf3dd98N60ePHg3rqYqW+G7mz32zsUQ3gBDhBxJF+IFE\nEX4gUYQfSBThBxJF+IFEHTPH80dzupK0cOHCsD5tWu1nIS+ax7/xxhvDenTcuXRsz0k3Eq9bjC0/\nkCjCDySK8AOJIvxAogg/kCjCDySK8AOJKjye38wekfRTST3uPjW7bYykJyRNlLRL0tXunr9O9D8e\nq9TE63HH5f+uuu6668KxDz30UFgfOjTe5SE6t/6MGTPCsVXO4w8fPjysjx49Oqzv378/rE+cODGs\nn3DCCbm1Dz/8MBz70UcfhfWi8/pHTjnllLA+bNiwsB79uyRp9+7dYb3M0uVF6nk8/2OSLv3WbYsk\nrXX3KZLWZtcBDCKF4Xf3dZK+/et/tqQV2eUVkq6oc18AGqzWz/zt7t6dXd4rqb1O/QBoktL79ru7\nR5/lzWy+pPllnwdAfdW65d9nZuMlKfueu9qjuy9z9w53zz8DJoCmqzX8z0qam12eK+mZ+rQDoFkK\nw29mj0t6RdK/mNluM/uZpF9KmmVm2yVdnF0HMIgUfuZ39zk5pZ/UuZdC0Zz1rFmzwrFF8/hFtmzZ\nklsrOq9+o48rj16XefPmhWPnzMn77+0V/bsl6dJLvz0L/E0jR47Mrb333nvh2E2bNoX1p556KqxH\n7r777rDe1tYW1k888cSwvnr16rB+//3359a6urrCsfXCHn5Aogg/kCjCDySK8AOJIvxAogg/kKhB\ntUT3rbfemlu78847w7HRlJMUH7IrxUt0v/POO+HYRps6dWpube3ateHYsWPHhvWiU6JXqWhp9Eh0\neHg9FOUqmh4+++yzw7FFS7KzRDeAEOEHEkX4gUQRfiBRhB9IFOEHEkX4gUS11BLdRYfd3nLLLbm1\nonn8onnXrVu3hvXt27eH9Sp9/vnnubWdO3eGY3ft2hXWTz/99LBeZj+Ak046KawPGTKkVD36Py/6\neSi7/0vRqbnXrFmTWyuax68XtvxAogg/kCjCDySK8AOJIvxAogg/kCjCDySqpeb5i4waNarmsUXz\nrsuXLw/rzZp7rUW0D0JnZ2c4tmg+u+xx79F+AKtWrQrHXnXVVaWeO/q3LV68OBz7/PPPl3ruouXH\ne3pyF7lqGrb8QKIIP5Aowg8kivADiSL8QKIIP5Aowg8kqnCe38wekfRTST3uPjW7bYmkeZL+kt1t\nsbu/0Kgm66Fonn/Dhg1N6qS5yu6fUObc+JI0YsSI3NrkyZNLPXaRgwcP5tZWrlwZju3u7q53Oy1n\nIFv+xyT1twj7r9393OyrpYMP4LsKw+/u6yTtb0IvAJqozGf+BWb2ppk9YmYn160jAE1Ra/gfkDRJ\n0rmSuiX9Ku+OZjbfzDaa2cYanwtAA9QUfnff5+5H3f1LSQ9JmhHcd5m7d7h7/kqXAJqupvCb2fg+\nV6+UFJ/6FkDLGchU3+OSLpI01sx2S7pD0kVmdq4kl7RL0o0N7BFAAxSG393n9HPzww3oBYNQ0VoL\nd9xxR27tvPPOK/Xchw4dCuuLFi3KraUwj1+EPfyARBF+IFGEH0gU4QcSRfiBRBF+IFGD6tTdZUSH\nlkrS+eefH9Y3b95cz3ZaRtGpuYvqZ555Zli/+OKLv3dPA/Xqq6+G9eeee65hz30sYMsPJIrwA4ki\n/ECiCD+QKMIPJIrwA4ki/ECiWmqev+g00dFS1GPGjAnHjhw5Mqzfc889YT3y8ssv1zy2Hi644ILc\n2rRp08Kx06dPD+tnnHFGWC9aNr1o/4rIBx98ENZnz54d1j/99NOanzsFbPmBRBF+IFGEH0gU4QcS\nRfiBRBF+IFGEH0jUoJrnf/TRR3Nr55xzTji2aJ5/9OjRYf3ee+/NrRUt/91o0Vz68OHDSz120fH8\nZRSdejt6zSXm8ctiyw8kivADiSL8QKIIP5Aowg8kivADiSL8QKIK5/nNbIKklZLaJbmkZe5+n5mN\nkfSEpImSdkm62t0PNK5VadWqVbm1ovnsJUuWhPW2trawHs2lF+1DUKRo/4Yihw8fzq0VHRO/c+fO\nsB6dK0CShg0bFtYjjz32WFh/4IEHan5sFBvIlv+IpF+4+48kXSDpZjP7kaRFkta6+xRJa7PrAAaJ\nwvC7e7e7v5Zd/kTSNkmnSpotaUV2txWSrmhUkwDq73t95jeziZJ+LGm9pHZ3785Ke9X7sQDAIDHg\nffvN7AeSfifp5+7+VzP7uububmaeM26+pPllGwVQXwPa8pvZ8eoN/m/c/ffZzfvMbHxWHy+pp7+x\n7r7M3TvcvaMeDQOoj8LwW+8m/mFJ29y972FWz0qam12eK+mZ+rcHoFHMvd936/+4g1mnpD9K2iLp\nqzmpxer93P+kpNMk/Um9U337Cx4rfrISiqb6ZsyYEdYvu+yysD5z5szc2qRJk8KxRYeuvvDCC2G9\naCrw9ddfz60VTeUtXbo0rBctwd33419/Dh48mFsrmkbs6uoK6+ifu8f/KZnCz/zu/n+S8h7sJ9+n\nKQCtgz38gEQRfiBRhB9IFOEHEkX4gUQRfiBRLXXq7jKiw1olad26daXq0Smsy57e+siRI6XGR/s4\n3HTTTeHYsvP4Bw7ER3HfdtttuTXm8avFlh9IFOEHEkX4gUQRfiBRhB9IFOEHEkX4gUQdM/P8jRYd\nU1/21NtlTZgwIbe2cOHCcGzRPH6RJ554Iqw/+eSTpR4fjcOWH0gU4QcSRfiBRBF+IFGEH0gU4QcS\nRfiBRBWet7+uT9bA8/anLFo+/K677grH3nDDDWH9448/DutF6yF0d3eHddTfQM/bz5YfSBThBxJF\n+IFEEX4gUYQfSBThBxJF+IFEFc7zm9kESSsltUtyScvc/T4zWyJpnqS/ZHdd7O7hQvPM8zffuHHj\nwvrMmTPD+vvvvx/WN2zYENabuR8Jeg10nn8gJ/M4IukX7v6ambVJ2mRmL2a1X7v7f9baJIDqFIbf\n3bsldWeXPzGzbZJObXRjABrre33mN7OJkn4saX120wIze9PMHjGzk3PGzDezjWa2sVSnAOpqwOE3\nsx9I+p2kn7v7XyU9IGmSpHPV+87gV/2Nc/dl7t7h7h116BdAnQwo/GZ2vHqD/xt3/70kufs+dz/q\n7l9KekhSfIQHgJZSGH7rPb3rw5K2ufu9fW4f3+duV0raWv/2ADTKQKb6OiX9UdIWSV+do3qxpDnq\nfcvvknZJujH742D0WMz7AA020Kk+jucHjjEczw8gRPiBRBF+IFGEH0gU4QcSRfiBRBF+IFGEH0gU\n4QcSRfiBRBF+IFGEH0gU4QcSRfiBRA3k7L319KGkP/W5Pja7rRW1am+t2pdEb7WqZ2//PNA7NvV4\n/u88udnGVj23X6v21qp9SfRWq6p6420/kCjCDySq6vAvq/j5I63aW6v2JdFbrSrprdLP/ACqU/WW\nH0BFKgm/mV1qZu+a2Q4zW1RFD3nMbJeZbTGzN6peYixbBq3HzLb2uW2Mmb1oZtuz7/0uk1ZRb0vM\nbE/22r1hZpdX1NsEM/tfM3vbzN4ys3/Lbq/0tQv6quR1a/rbfjMbIqlL0ixJuyVtkDTH3d9uaiM5\nzGyXpA53r3xO2Mz+VdLfJK1096nZbf8hab+7/zL7xXmyu9/eIr0tkfS3qlduzhaUGd93ZWlJV0i6\nXhW+dkFfV6uC162KLf8MSTvcfae7/13SbyXNrqCPlufu6yTt/9bNsyWtyC6vUO8PT9Pl9NYS3L3b\n3V/LLn8i6auVpSt97YK+KlFF+E+V9Oc+13ertZb8dkl/MLNNZja/6mb60d5nZaS9ktqrbKYfhSs3\nN9O3VpZumdeulhWv640/+H1Xp7ufJ+kySTdnb29bkvd+Zmul6ZoBrdzcLP2sLP21Kl+7Wle8rrcq\nwr9H0oQ+13+Y3dYS3H1P9r1H0tNqvdWH9321SGr2vafifr7WSis397eytFrgtWulFa+rCP8GSVPM\n7HQzGybpWknPVtDHd5jZqOwPMTKzUZIuUeutPvyspLnZ5bmSnqmwl29olZWb81aWVsWvXcuteO3u\nTf+SdLl6/+L/nqR/r6KHnL7OkLQ5+3qr6t4kPa7et4FfqPdvIz+T9E+S1kraLul/JI1pod5WqXc1\n5zfVG7TxFfXWqd639G9KeiP7urzq1y7oq5LXjT38gETxBz8gUYQfSBThBxJF+IFEEX4gUYQfSBTh\nBxJF+IFE/T/GJWizcDg1bwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "inference time on GPU:  0.17154407501220703\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z_4v8KuKMMwo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}